(Draft)[Main][feat] Support overlapping A2A Combine backprop with wgrad GEMM #3792

Closed

Wohox wants to merge 7 commits into NVIDIA:main from Wohox:pingtian/support_backawrd_dw_for_fsdp_main

Conversation

@Wohox (Contributor) commented Mar 11, 2026

What does this PR do?

PR for dev: #3766

Problem

In MoE models, the expert weight gradient (wgrad) computation during backward is serialized on the main CUDA stream. This blocks the data gradient (dgrad) from flowing to earlier layers until the expert wgrad finishes, even though there is no data dependency between them. The result is wasted GPU cycles — earlier layers' backward pass sits idle waiting for expert wgrad to complete.

With FSDP, this is further compounded because the gradient reduce-scatter for expert parameters is also blocked on the same critical path.

Solution

This PR introduces a new flag --delay-wgrad-compute-for-te-grouped-gemm that separates the expert wgrad computation from the main backward stream:

  1. Two autograd functions are inserted into the MoE layer's forward graph:

    • _RecordExpertDgradCompletion — placed before the expert computation; during backward, it records a CUDA event once the expert dgrad is done (both functions are sketched after this list).
    • _RegisterDelayedWgradForExperts — placed at the dispatch boundary; during backward, it waits on the dgrad event, then launches backward_dw() on a dedicated CUDA stream, and synchronizes back to the main stream before proceeding.
  2. FSDP integration — When used with MegatronFSDP, expert parameters are marked with _fsdp_delay_grad_reduce = True so the normal post-accumulate-grad hook skips them. A callback is registered via register_process_expert_grads_fn() that triggers the FSDP reduce-scatter for expert parameters only after the delayed wgrad computation completes.

  3. TE GroupedLinear is configured with delay_wgrad_compute=True, which tells Transformer Engine to skip wgrad during the normal autograd backward and instead wait for an explicit backward_dw() call.
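
As a rough sketch of the two autograd functions from step 1 (the stream/event choreography and the backward_dw(routed_experts=True, shared_experts=False) call mirror the PR's inline snippet further down; everything else here is illustrative, not the actual implementation):

import torch

class _RecordExpertDgradCompletion(torch.autograd.Function):
    """Identity in forward; in backward, records that expert dgrad has finished."""

    @staticmethod
    def forward(ctx, hidden_states, dgrad_event):
        ctx.dgrad_event = dgrad_event
        return hidden_states

    @staticmethod
    def backward(ctx, grad_output):
        # This node sits before the experts in forward, so in backward it runs
        # after the expert dgrad kernels have been enqueued on the main stream.
        ctx.dgrad_event.record(torch.cuda.current_stream())
        return grad_output, None

class _RegisterDelayedWgradForExperts(torch.autograd.Function):
    """Identity in forward; in backward, launches the delayed wgrad on a side stream."""

    @staticmethod
    def forward(ctx, hidden_states, module, dgrad_event, wgrad_stream):
        ctx.module = module
        ctx.dgrad_event = dgrad_event
        ctx.wgrad_stream = wgrad_stream
        return hidden_states

    @staticmethod
    def backward(ctx, grad_output):
        ctx.wgrad_stream.wait_event(ctx.dgrad_event)  # wait for expert dgrad
        with torch.cuda.stream(ctx.wgrad_stream):
            ctx.module.backward_dw(routed_experts=True, shared_experts=False)
        ctx.dgrad_event.record(ctx.wgrad_stream)
        # Re-join the main stream so later backward work sees finished wgrads.
        torch.cuda.current_stream().wait_event(ctx.dgrad_event)
        return grad_output, None, None, None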

How to enable

--delay-wgrad-compute-for-te-grouped-gemm

Requirements:

  • Transformer Engine >= 2.3.0
  • moe_grouped_gemm enabled (not legacy grouped gemm)
  • Mutually exclusive with --delay-wgrad-compute (the existing A2A-overlap-based delay)
  • Mutually exclusive with --overlap-moe-expert-parallel-comm

Works with both FSDP and 3-D parallelism (TP/EP/PP).
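
As a rough config-level sketch of what the flag maps to (field values are placeholders, the new attribute exists only on this PR's branch, and real runs set many more options):

from megatron.core.transformer import TransformerConfig

config = TransformerConfig(
    num_layers=2,                     # placeholder model shape
    hidden_size=1024,
    num_attention_heads=8,
    num_moe_experts=8,
    moe_grouped_gemm=True,            # required: TE grouped GEMM path
    delay_wgrad_compute=False,        # mutually exclusive with the new flag
    delay_wgrad_compute_for_te_grouped_gemm=True,  # added by this PR
)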

What is achieved

The expert wgrad computation runs on a separate CUDA stream, overlapping with the EP communication within the same transformer layer. This reduces the wall-clock time of the backward pass without changing numerical results — the feature is bit-exact with the non-delayed baseline (verified by unit tests comparing per-step losses and final weights over multiple optimizer steps).

Changes

  • megatron/core/model_parallel_config.py: new config flag delay_wgrad_compute_for_te_grouped_gemm
  • megatron/core/transformer/transformer_config.py: validation assertions for the new flag
  • megatron/core/transformer/moe/moe_layer.py: autograd functions for delayed wgrad, the dedicated CUDA stream/event, and the register_process_expert_grads_fn callback
  • megatron/core/extensions/transformer_engine.py: passes delay_wgrad_compute=True to TE GroupedLinear when the new flag is set
  • megatron/core/distributed/fsdp/.../megatron_fsdp.py: FSDP hook to defer reduce-scatter for expert params and trigger it after delayed wgrad
  • tests/unit_tests/a2a_overlap/test_delay_wgrad_compute.py: unit tests covering basic, shared-expert, multi-layer, and FSDP scenarios

Test plan

  • Unit test: test_delay_wgrad_compute_for_te_grouped_gemm — full-model training loop (forward → backward → optimizer) comparing delayed vs. non-delayed across num_layers × shared_experts × dispatcher_type × fp8_flag (the comparison pattern is sketched after this list)
  • Unit test: test_delay_wgrad_compute_for_te_grouped_gemm_with_fsdp — same comparison with MegatronFSDP wrapping (fully_shard_model + fully_shard_optimizer), verifying the deferred reduce-scatter path
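
A hedged sketch of that comparison pattern (the stand-in model below is a plain Linear purely to keep the snippet runnable; the real tests build a Megatron MoE block and toggle delay_wgrad_compute_for_te_grouped_gemm between the two runs):

import torch

def train_steps(seed: int, num_steps: int = 4):
    # Stand-in model; the real test constructs a Megatron MoE block here and
    # flips the delay-wgrad flag between the reference and test runs.
    torch.manual_seed(seed)
    model = torch.nn.Linear(16, 16)
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    data = torch.randn(num_steps, 4, 16)
    losses = []
    for step in range(num_steps):
        opt.zero_grad()
        loss = model(data[step]).square().mean()
        loss.backward()
        opt.step()
        losses.append(loss.detach().clone())
    return losses, [p.detach().clone() for p in model.parameters()]

ref_losses, ref_params = train_steps(seed=0)
new_losses, new_params = train_steps(seed=0)  # real test: same seed, flag enabled

# Bit-exactness: per-step losses and final weights must match exactly.
assert all(torch.equal(a, b) for a, b in zip(ref_losses, new_losses))
assert all(torch.equal(a, b) for a, b in zip(ref_params, new_params))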

Contribution process

Pre-checks

  • I have added relevant unit tests
  • I have added relevant functional tests
  • I have added proper typing to my code (see Typing guidelines)
  • I have added relevant documentation
  • I have run the autoformatter.sh on my PR

Code review

Feel free to message or tag @mcore-oncall in a comment to help accelerate your merge into main. The less complex your PR is, the faster it will be approved and merged!

All PRs start as draft. If you open a non-draft PR, it will be automatically converted to draft.

Step 1: Mark PR as "Ready for Review"

  1. When your PR is ready, click Ready for Review.
  2. An oncall reviewer is auto-assigned and expert reviewers are notified based on your changes.
    • Some PRs may jump straight to step 2. This is determined by .github/CODEOWNERS.

⚠️ Only mark as ready once merge-conflicts are resolved and the CI is passing.
Final Review might get declined if these requirements are not fulfilled.

Step 2: Final Review

For PRs that change megatron/core, once all expert reviewers have approved, the Final Review label is applied automatically and final reviewers are assigned.

For PRs outside megatron/core, this step is skipped.

Step 3: Approved

Once all required reviewers have approved, the Approved label is applied automatically.

Merge

Any member of mcore-engineers will be able to merge your PR.

For MRs into the `dev` branch: the proposed review process is under active discussion. MRs are mergeable after one approval by either eharper@nvidia.com or zijiey@nvidia.com.

@copy-pr-bot (bot) commented Mar 11, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@Wohox Wohox changed the title Support Delay wgrad gemm (Draft)[Main][feat] Support Delay wgrad gemm Mar 11, 2026
@Wohox Wohox changed the title (Draft)[Main][feat] Support Delay wgrad gemm (Draft)[Main][feat] Support overlapping A2A Combine backprop with wgrad GEMM Mar 11, 2026
@Wohox (Contributor, Author) commented Mar 11, 2026

/claude review

@claude (bot) commented Mar 11, 2026

Review

The approach is solid — separating expert wgrad onto a dedicated CUDA stream to overlap with the dispatch backward (A2A combine backprop) is well-structured with the two autograd function design.

One inline suggestion on shared_experts=True in the delayed wgrad backward call.

Missing test coverage

There are no tests for the new delay_wgrad_compute_for_te_grouped_gemm feature. The only test file change (tests/unit_tests/a2a_overlap/utils.py) adds get_valid_flex_dispatcher_backends(), which appears unrelated. Given the complexity of the CUDA stream synchronization here, at minimum there should be a test verifying:

  1. The config validation assertions (mutual exclusivity with delay_wgrad_compute, TE version check, incompatibility with overlap_moe_expert_parallel_comm).
  2. Numerical correctness — that a forward+backward pass with delay_wgrad_compute_for_te_grouped_gemm=True produces the same gradients as without it.

Wohox and others added 2 commits March 10, 2026 19:09
Co-authored-by: claude[bot] <209825114+claude[bot]@users.noreply.github.com>
@Wohox (Contributor, Author) commented Mar 11, 2026

/claude review

@claude (bot) left a comment

Light review — overall this looks well-structured. The CUDA stream synchronization pattern is correct and the FSDP integration is clean.

One observation on test parametrization: test_delay_wgrad_compute_for_te_grouped_gemm and the FSDP variant parametrize flex_dispatcher_backend independently of dispatcher_type, which creates redundant test combinations (e.g., dispatcher_type="alltoall" × all flex backends) where the backend parameter is unused. This multiplies CI cost without adding coverage. Consider either collapsing these into a single parametrize that yields valid (dispatcher_type, backend) pairs, or using pytest.mark.parametrize with indirect to skip invalid combos.
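
For instance, a collapsed parametrization along those lines might look like this (the pair values are illustrative placeholders, not the repo's actual dispatcher or backend names):

import pytest

# Enumerate only valid (dispatcher_type, flex_backend) pairs, so the flex
# backend never varies for dispatchers that ignore it.
VALID_COMBOS = [
    ("alltoall", None),      # flex backend unused for the plain A2A dispatcher
    ("flex", "backend_a"),   # placeholder flex backend names
    ("flex", "backend_b"),
]

@pytest.mark.parametrize("dispatcher_type,flex_backend", VALID_COMBOS)
def test_delay_wgrad_compute(dispatcher_type, flex_backend):
    ...  # body elided; the real tests run the full training-loop comparison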

@Wohox (Contributor, Author) commented Mar 11, 2026

/claude review

@claude (bot) commented Mar 11, 2026

Light review — one bug found, otherwise looks good.

Bug: In megatron_fsdp.py line 765, getattr(module.mlp, '_process_expert_grads_fn') will raise AttributeError for modules that don't have a .mlp attribute (e.g., embedding layers). This hook fires for every module in _pre_forward_param_unshard, not just TransformerLayers. See inline comment for a suggested fix.
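
A minimal sketch of such a guard, assuming the hook should simply no-op for modules without an MoE mlp (the attribute names come from the comment above; the surrounding call site is hypothetical):

# Guarded lookup: modules without .mlp, or with an mlp lacking the callback,
# fall through without raising AttributeError.
mlp = getattr(module, "mlp", None)
process_expert_grads = getattr(mlp, "_process_expert_grads_fn", None)
if callable(process_expert_grads):
    process_expert_grads()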

Everything else looks solid — the CUDA stream synchronization logic is correct, the mutual exclusivity assertion between delay_wgrad_compute and delay_wgrad_compute_for_te_grouped_gemm is consistent with how they're combined in transformer_engine.py, and the test coverage is thorough (multi-step loss + weight comparison, with and without FSDP).

@Wohox (Contributor, Author) commented Mar 11, 2026

/claude review

Comment on lines +662 to +668
wgrad_stream.wait_event(event)
with torch.cuda.stream(wgrad_stream):
    with torch.cuda.nvtx.range("delayed_expert_wgrad"):
        module.backward_dw(routed_experts=True, shared_experts=False)
event.record(wgrad_stream)

torch.cuda.current_stream().wait_event(event)

Bug: When moe_latent_size is configured, backward_dw() internally launches fc2_latent_proj.backward_dw() on comm_stream (line 585-586). However, the event here is only recorded on wgrad_stream, so torch.cuda.current_stream().wait_event(event) on line 668 does not wait for the comm_stream work to finish. This could lead to a data race where the main stream proceeds (e.g., starts the next iteration) before the latent projection weight gradients are fully computed.

You'd need to either:

  1. Also synchronize comm_stream back to wgrad_stream (or main stream) before recording the event (see the sketch after this list), or
  2. Add a validation assertion that delay_wgrad_compute_for_te_grouped_gemm and moe_latent_size are mutually exclusive (if that combination isn't intended to be supported yet).
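
A fragment-level sketch of option 1 against the snippet quoted above (comm_stream refers to the stream used inside backward_dw per this comment; this is illustrative, not a drop-in patch):

with torch.cuda.stream(wgrad_stream):
    with torch.cuda.nvtx.range("delayed_expert_wgrad"):
        module.backward_dw(routed_experts=True, shared_experts=False)
# Make wgrad_stream wait for the fc2_latent_proj work enqueued on comm_stream
# before the completion event is recorded, so the main stream's wait covers it.
comm_done = torch.cuda.Event()
comm_done.record(comm_stream)
wgrad_stream.wait_event(comm_done)
event.record(wgrad_stream)
torch.cuda.current_stream().wait_event(event)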

@Wohox
Copy link
Contributor Author

Wohox commented Mar 11, 2026

/claude review

@claude (bot) left a comment

Looks good overall — clean separation of the delayed wgrad path into two autograd functions, and the FSDP integration is well-structured.

One potential bug flagged as an inline comment: when moe_latent_size is configured, backward_dw() dispatches work to comm_stream which is not synchronized before the event is recorded on wgrad_stream. This could cause a data race if the two features are combined.

Comment on lines +265 to +268
"""Delay the weight gradient computation for TE Grouped GEMM MoE experts.
When enabled with FSDP, the expert weight gradients are computed on a separate
CUDA stream after the data gradients finish, allowing overlap of wgrad compute
with the backward pass of earlier layers. The FSDP gradient reduce-scatter for

The docstring says the wgrad overlaps "with the backward pass of earlier layers," but looking at the implementation in _RegisterDelayedWgradForExperts.backward, the main backward stream synchronizes (current_stream().wait_event(event)) before returning — so earlier layers' backward cannot start until wgrad finishes. The actual overlap is between the wgrad computation and the A2A combine backward (dispatch backward) within the same layer.

Suggested change

- """Delay the weight gradient computation for TE Grouped GEMM MoE experts.
- When enabled with FSDP, the expert weight gradients are computed on a separate
- CUDA stream after the data gradients finish, allowing overlap of wgrad compute
- with the backward pass of earlier layers. The FSDP gradient reduce-scatter for
+ """Delay the weight gradient computation for TE Grouped GEMM MoE experts.
+ When enabled, the expert weight gradients are computed on a separate
+ CUDA stream after the data gradients finish, allowing overlap of wgrad compute
+ with the A2A combine communication within the same MoE layer. When used with
+ FSDP, the gradient reduce-scatter for expert parameters is deferred until the
+ delayed wgrad computation completes.
+ This requires transformer_engine with GroupedLinear support (TE >= 2.3.0).

@@ -0,0 +1,202 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.

Nit: copyright year should be 2026.

Suggested change

- # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+ # Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.

@Wohox Wohox closed this Mar 11, 2026